import glob
import os
from typing import Tuple
import numpy as np
import sys
import pandas as pd 

sys.path.append('./')
from classes.data.datasets.BaseDataset import BaseDataset

# `device` argument (1-based) -> data folder: 1=HuaweiMate30, 2=HuaweiP30PRO,
# 3=iphone14pm, 4=vivoiqooneo5, 5=Xiaomi11PRO

class CTADataset(BaseDataset):
    """CTA-Set training split for a single capture device.

    In ``"train"`` mode the per-device index file
    ``./TAWB/dataset/CTA-Set/train_<name>.npy`` is loaded. Each entry of its
    ``'id'`` field becomes a sequence path ``./NPY1/<DeviceFolder>/<id>``, and
    its ``'num'`` field is stored in ``_nums_to_seqs``. In any other mode this
    class adds nothing beyond what ``BaseDataset.__init__`` sets up.
    """

    def __init__(self, mode: str = "train", input_size: Tuple = (224, 224), device: int = 1):
        """Build the dataset for one device.

        Args:
            mode: Dataset split name; only ``"train"`` populates sequences here.
            input_size: Forwarded unchanged to ``BaseDataset``.
            device: 1-based device index (1=HuaweiMate30, 2=HuaweiP30PRO,
                3=iphone14pm, 4=vivoiqooneo5, 5=Xiaomi11PRO).

        Raises:
            ValueError: If ``device`` is outside ``1..5``.
        """
        # Folder names under ./NPY1/ and split-file suffixes, both indexed by device - 1.
        dataset_device = ['HuaweiMate30', 'HuaweiP30PRO', 'iphone14pm', 'vivoiqooneo5', 'Xiaomi11PRO']
        num_device = ['mate30', 'P30pro', 'iphonepm', 'vivo', 'xiaomi11pro']
        # Guard explicitly: with device <= 0, `device - 1` would be a negative
        # index and silently select a device from the end of the list.
        if not 1 <= device <= len(dataset_device):
            raise ValueError(
                f"device must be in 1..{len(dataset_device)}, got {device!r}"
            )
        super().__init__(mode, input_size, device)

        path_to_dataset = './NPY1/' + dataset_device[device - 1] + '/'
        if mode == 'train':
            train_path = './TAWB/dataset/CTA-Set/train_' + num_device[device - 1] + '.npy'
            # The file holds a dict-like object in a 0-d object array, so
            # allow_pickle=True + .item() are required; only load trusted files.
            train_info = np.load(train_path, allow_pickle=True).item()
            self._paths_to_seqs.extend(
                path_to_dataset + str(seq_id) for seq_id in train_info['id']
            )
            # Extend rather than replace so _nums_to_seqs stays index-aligned
            # with _paths_to_seqs, whatever BaseDataset initialised them to.
            self._nums_to_seqs.extend(train_info['num'])

# training_set = TCSRDataset(mode="test")
# training_set.__getitem__(1)

class TrainAllDataset(BaseDataset):
    """CTA-Set training split pooled over all five capture devices.

    In ``"train"`` mode every device's index file
    ``./TAWB/dataset/CTA-Set/train_<name>.npy`` is loaded in device order.
    Its ``'id'`` entries become paths ``./NPY1/<DeviceFolder>/<id>``, and the
    matching ``'num'`` entries are appended to ``_nums_to_seqs``. The
    ``device`` argument is only forwarded to ``BaseDataset``; it does not
    restrict which devices are loaded.
    """

    def __init__(self, mode: str = "train", input_size: Tuple = (224, 224), device: int = 1):
        """Build the pooled dataset.

        Args:
            mode: Dataset split name; only ``"train"`` populates sequences here.
            input_size: Forwarded unchanged to ``BaseDataset``.
            device: Forwarded unchanged to ``BaseDataset``; not used for loading.

        Raises:
            IndexError: If an index file's ``'num'`` field is shorter than its
                ``'id'`` field.
        """
        super().__init__(mode, input_size, device)
        # Folder names under ./NPY1/ and the matching split-file suffixes.
        dataset_device = ['HuaweiMate30', 'HuaweiP30PRO', 'iphone14pm', 'vivoiqooneo5', 'Xiaomi11PRO']
        num_device = ['mate30', 'P30pro', 'iphonepm', 'vivo', 'xiaomi11pro']
        if mode == 'train':
            # Iterate folder/suffix pairs directly instead of re-using the name
            # `device`, which would shadow the constructor argument.
            for folder, split_name in zip(dataset_device, num_device):
                path_to_dataset = './NPY1/' + folder + '/'
                train_path = './TAWB/dataset/CTA-Set/train_' + split_name + '.npy'
                # Dict-like payload stored in a 0-d object array; only load trusted files.
                train_info = np.load(train_path, allow_pickle=True).item()
                train_ids = train_info['id']
                train_nums = train_info['num']
                for i, seq_id in enumerate(train_ids):
                    self._paths_to_seqs.append(path_to_dataset + str(seq_id))
                    # Indexed (not zipped) so a short 'num' list still raises.
                    self._nums_to_seqs.append(train_nums[i])